import logging
import os
import os.path
import pickle
import subprocess
import sys
import time
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeoutError
from threading import Lock
from typing import Optional

from settings import WORK_DIR


logger = logging.getLogger('audit')
# Persisted scanner settings shared through WORK_DIR/.scanner.cfg:
#   timestamp   - used to tell whether a stored config is newer than the applied one
#   scan_params - extra whitespace-separated command-line arguments for the scanner
#   force       - restart the scanner even when scan_params did not change
ScannerConfig = namedtuple('ScannerConfig', ('timestamp', 'scan_params', 'force'))


def get_scanner_config() -> Optional[ScannerConfig]:
    """Load the persisted scanner configuration from ``WORK_DIR/.scanner.cfg``.

    Returns:
        The unpickled config object, or ``None`` when the file does not exist
        or cannot be unpickled because it is empty/truncated (for example when
        it is read while ``update_scanner_config`` is rewriting it).

    NOTE(review): the file is loaded with ``pickle``, which can execute
    arbitrary code; this is only safe while WORK_DIR is writable solely by
    trusted processes.
    """
    f_cfg = WORK_DIR / '.scanner.cfg'
    try:
        with open(f_cfg, 'rb') as fp:
            return pickle.load(fp)
    except FileNotFoundError:
        return None
    except (EOFError, pickle.UnpicklingError) as err:
        # The writer truncates the file before dumping, so a concurrent read can
        # see partial data; treat it as "no config" instead of failing the scan.
        logger.warning(f'[{os.getpid()}]scanner config unreadable, ignored: {err!r}')
        return None


def update_scanner_config(cfg: ScannerConfig):
    """Persist *cfg* as the shared scanner configuration.

    The value is pickled into ``WORK_DIR/.scanner.cfg``, replacing any previous
    content; ``get_scanner_config`` reads it back.
    """
    cfg_path = WORK_DIR / '.scanner.cfg'
    with open(cfg_path, mode='wb') as out:
        out.write(pickle.dumps(cfg))


class ConScanner:
    """Drive an external scanner executable running in ``-daemon`` mode.

    Scan requests are written to the daemon's stdin and answers are read back
    from its stdout. A single-worker thread pool runs each request so that
    ``scan_file`` can enforce a timeout; on timeout the daemon is killed and it
    is started again lazily by the next ``scan_file`` call.
    """

    def __init__(self, scanner_dir, scanner_name=None, scan_params=None):
        """Validate *scanner_dir* and remember how to launch the scanner.

        Args:
            scanner_dir: Directory holding the scanner executable and the
                ``malware.rmd`` file that ``run`` requires; also used as the
                daemon's working directory.
            scanner_name: Executable file name; defaults to
                ``lame.scanner.exe`` on Windows and ``lame.scanner`` elsewhere.
            scan_params: Extra whitespace-separated command-line arguments.
                Overridden by the persisted config when one exists.

        Raises:
            NotADirectoryError: if *scanner_dir* is not a directory.

        The daemon is not started here; ``run`` (or the first ``scan_file``)
        launches it.
        """
        if not os.path.isdir(scanner_dir):
            raise NotADirectoryError(scanner_dir)
        if not scanner_name:
            scanner_name = 'lame.scanner.exe' if sys.platform == 'win32' else 'lame.scanner'
        self.scanner_dir = scanner_dir
        self.scanner_name = scanner_name
        self.scan_params = scan_params
        self.scanner_cfg = get_scanner_config()
        if self.scanner_cfg:
            # A persisted config takes precedence over the constructor argument.
            self.scan_params = self.scanner_cfg.scan_params

        self.scanner_process = None
        self.thread_pool = None
        self.lock = Lock()

    def run(self):
        """Start the scanner daemon and wait until it reports it is ready.

        Readiness is a stdout line ending with ``entering daemon mode...``.

        Raises:
            FileNotFoundError: if the executable or ``malware.rmd`` is missing.
            Exception: if the daemon exits during start-up, is not ready within
                roughly 60 seconds, or is not running afterwards. In every
                failure case the half-started child process is killed.
        """
        with self.lock:
            exe_path = os.path.join(self.scanner_dir, self.scanner_name)
            if not os.path.isfile(exe_path):
                raise FileNotFoundError(exe_path)
            f_vlib = os.path.join(self.scanner_dir, 'malware.rmd')
            if not os.path.isfile(f_vlib):
                raise FileNotFoundError(f_vlib)
            params = [exe_path, '-daemon']
            if self.scan_params:
                params.extend(self.scan_params.split())
            # stderr is never read: a PIPE would eventually fill the OS pipe
            # buffer and block the daemon, so discard it instead.
            p = subprocess.Popen(params, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                                 stderr=subprocess.DEVNULL, cwd=self.scanner_dir)
            try:
                self._wait_ready(p, exe_path)
                if not self._is_running(p):
                    raise Exception('run scanner failed, {} {}'.format(exe_path, self.scan_params))
            except BaseException:
                # Do not leave a half-started daemon (and its pipes) behind.
                self._kill(p)
                self._close_pipes(p)
                raise
            logger.info('[{}]scanner is running... {} {}'.format(os.getpid(), exe_path, self.scan_params))
            self.scanner_process = p
            self.thread_pool = ThreadPoolExecutor(max_workers=1)  # only one task can be processed at a time

    def close(self):
        """Stop the scanner daemon (if any) and its worker thread pool.

        Safe to call repeatedly. Afterwards ``scanner_process`` and
        ``thread_pool`` are ``None``, so the next ``scan_file`` restarts the
        daemon.
        """
        with self.lock:
            p, pool = self.scanner_process, self.thread_pool
            if p:
                # Kill first: a scan blocked in readline() only returns once the
                # daemon's stdout reaches EOF, and shutdown() waits for that scan.
                self._kill(p)
            if pool:
                pool.shutdown()
            if p:
                self._close_pipes(p)
            logger.info('[{}]scanner closed... {}'.format(os.getpid(), self.scanner_dir))
            self.scanner_process, self.thread_pool = None, None

    def scan_file(self, f_path, timeout=30):
        """Scan *f_path* with the daemon, (re)starting it when needed.

        Args:
            f_path: Path of the file to scan.
            timeout: Seconds to wait for the daemon's answer. On timeout the
                daemon is killed and restarted on the next call.

        Returns:
            ``(overtime, result)``. ``overtime`` is True when the scan timed out
            (then ``result`` is ``''``). Otherwise ``result`` is ``None`` when
            *f_path* is not a file or the daemon printed nothing, ``''`` for a
            clean verdict, or the daemon's verdict text with everything up to
            the first ``':'`` removed.
        """
        self.check_scanner_cfg()
        if not self._is_running(self.scanner_process):
            if self.scanner_process:
                # The daemon died: release its executor and pipes before restarting.
                self.close()
            self.run()
        overtime = False
        future = self.thread_pool.submit(self._scan, self.scanner_process, f_path)
        try:
            result = future.result(timeout=timeout)
        except FutureTimeoutError:
            # Future.result raises concurrent.futures.TimeoutError, which is not
            # the builtin TimeoutError before Python 3.11.
            overtime, result = True, ''
            self.close()  # the scan timed out: forcibly terminate the scanner process
        if result:
            result = result.split(':', 1)[-1]
        logger.info('[{}]{}, {}'.format(os.getpid(), f_path, '|Timeout|' if overtime else result))
        return overtime, result

    def check_scanner_cfg(self):
        """Restart the daemon when the persisted config asks for it.

        A restart happens only when a config file exists, its ``force`` flag is
        set or its ``scan_params`` differ from the current ones, and it is newer
        than the config this instance loaded or last applied (or none was).
        """
        scanner_cfg = get_scanner_config()
        if not scanner_cfg:
            return
        if not scanner_cfg.force and self.scan_params == scanner_cfg.scan_params:
            return
        if self.scanner_cfg and self.scanner_cfg.timestamp >= scanner_cfg.timestamp:
            return
        logger.info(f'[{os.getpid()}]force restart or scanner config changed, restart scanner... {scanner_cfg}')
        self.scan_params = scanner_cfg.scan_params
        self.close()
        self.run()
        self.scanner_cfg = scanner_cfg

    def upgrade_lib(self, lib_name):
        """Run the ``libup`` tool from *scanner_dir* with ``-xml lib_name``.

        ``self.lock`` is held while the tool runs, so it cannot overlap
        ``run``/``close``.

        Returns:
            ``(ok, message)``. Exit code 0 gives ``(True, 'upgrade success')``,
            exit code 1 is reported as "lib is up to date", and any other code
            returns the tool's stderr (or stdout) text. A missing tool yields
            ``(False, 'upgrade failed: libup not found')``.
        """
        name = 'libup.exe' if sys.platform == 'win32' else 'libup'
        exe_path = os.path.join(self.scanner_dir, name)
        if not os.path.isfile(exe_path):
            logger.warning(f'[{os.getpid()}]libup not found, {exe_path}')
            return False, 'upgrade failed: libup not found'
        params = [exe_path, '-xml', lib_name]
        with self.lock:
            logger.info(f'[{os.getpid()}]upgrade lib... {params}')
            p = subprocess.Popen(params, stdin=subprocess.PIPE,
                                 stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=self.scanner_dir)
            # communicate() drains both pipes, so no deadlock risk here.
            output, err = p.communicate()
            if p.returncode == 0:
                logger.info(f'[{os.getpid()}]upgrade lib success')
                return True, 'upgrade success'
            if p.returncode == 1:
                logger.info(f'[{os.getpid()}]upgrade lib failed, lib is up to date')
                return False, 'upgrade failed: lib is up to date'
            err_msg = (err or output).decode(errors="ignore").strip()
            logger.warning(f'[{os.getpid()}]upgrade lib failed, {err_msg}')
            return False, f'upgrade failed: {err_msg}'

    @staticmethod
    def _wait_ready(p, exe_path):
        """Block until *p* prints its ``entering daemon mode...`` line.

        Raises:
            Exception: if *p* exits first or about 60 seconds pass.
        """
        begin = time.time()
        while True:
            raw = p.stdout.readline()
            line = raw.strip()
            if line.endswith(b'entering daemon mode...'):
                return
            if line:
                logger.debug('[{}]{}'.format(os.getpid(), line.decode(errors='replace')))
                continue
            if not raw and p.poll() is not None:
                # EOF and the process has exited: it will never become ready.
                logger.warning('[{}]scanner exited during start-up, returncode {}'.format(
                    os.getpid(), p.returncode))
                raise Exception('run scanner failed, {}'.format(exe_path))
            if (time.time() - begin) > 60:
                logger.warning('[{}]扫描器启动失败，已超时'.format(os.getpid()))
                raise Exception('run scanner failed, {}'.format(exe_path))
            logger.info('[{}]扫描器启动中...'.format(os.getpid()))
            time.sleep(5)

    @staticmethod
    def _is_running(p):
        """Return True if *p* is a started process that has not exited yet."""
        if not p:
            return False
        return p.poll() is None

    @staticmethod
    def _kill(p):
        """Terminate *p*, escalating to ``kill()`` if it ignores terminate().

        Makes up to three ``terminate()`` attempts, waiting 10 seconds after each.
        """
        if not p:
            return
        for _ in range(3):
            p.terminate()
            try:
                p.wait(timeout=10)
                return
            except subprocess.TimeoutExpired:
                continue
        p.kill()
        p.wait()

    @staticmethod
    def _close_pipes(p):
        """Close *p*'s stdio pipes (best effort)."""
        for stream in (p.stdin, p.stdout, p.stderr):
            if stream is None:
                continue
            try:
                stream.close()
            except OSError:
                # e.g. BrokenPipeError flushing stdin of an already-exited process;
                # the pipe is being discarded, so there is nothing left to do.
                pass

    @staticmethod
    def _scan(p, f_path):
        """Send one ``scan`` request to daemon *p* and parse its reply.

        The request line is ``scan <end marker> <f_path>`` encoded as GBK. Reply
        lines are read until the ``==-==`` end marker or an empty line/EOF.

        Returns:
            ``None`` if *f_path* is not a file or the daemon printed nothing;
            ``''`` when the first reply line has no second tab-separated field
            or that field is ``ok`` (case-insensitive); otherwise that field.

        Raises:
            Exception: if the daemon is not running.
        """
        if not os.path.isfile(f_path):
            return None
        if not p or p.poll() is not None:
            raise Exception('scanner is not running')
        end_str = '==-=='
        end_bytes = end_str.encode()
        p.stdin.write('scan {} {}\n'.format(end_str, f_path).encode('gbk'))
        p.stdin.flush()
        lines = []
        while True:
            line = p.stdout.readline().strip()
            if not line or line == end_bytes:
                break
            try:
                lines.append(line.decode('gbk'))
            except UnicodeDecodeError:
                lines.append(line.decode('utf-8', errors='ignore'))
        if lines:
            arr = [s for s in lines[0].split('\t') if s.strip()]
            if len(arr) > 1:
                v_name = arr[1].strip()
                if v_name.lower() == 'ok':
                    return ''
                return v_name
            return ''
        return None


# class LScanner:
#     def __init__(self, scanner_dir):
#         self.lame, self.dbf = None, None
#         _param_lst = []
#
#         _lame = LameScanner(scanner_dir)
#
#         _dbf = VirusDb(scanner_dir)
#         if not _dbf.OpenVdb(None):
#             logger.warning('open vdb failed')
#             return
#
#         for _param in _param_lst:
#             _lame.SetParam(_param)
#
#         if not _lame.Load(_dbf):
#             _dbf.CloseVdb()
#             logger.warning('load lame failed')
#             return
#         self.lame, self.dbf = _lame, _dbf
#
#     def close(self):
#         if self.lame:
#             self.lame.close()
#         if self.dbf:
#             self.dbf.close()
#         self.lame, self.dbf = None, None
#
#     def scan_file(self, f_path):
#         if not self.lame:
#             raise Exception('引擎没有初始化')
#         return self.lame.scan_file(f_path)
